{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 02参数管理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:09.649068Z",
     "iopub.status.busy": "2023-08-18T07:01:09.648305Z",
     "iopub.status.idle": "2023-08-18T07:01:10.928992Z",
     "shell.execute_reply": "2023-08-18T07:01:10.927959Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0970],\n",
       "        [-0.0827]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))\n",
    "X = torch.rand(size=(2, 4))\n",
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:10.933865Z",
     "iopub.status.busy": "2023-08-18T07:01:10.933267Z",
     "iopub.status.idle": "2023-08-18T07:01:10.939922Z",
     "shell.execute_reply": "2023-08-18T07:01:10.938931Z"
    },
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OrderedDict([('weight', tensor([[-0.0427, -0.2939, -0.1894,  0.0220, -0.1709, -0.1522, -0.0334, -0.2263]])), ('bias', tensor([0.0887]))])\n"
     ]
    }
   ],
   "source": [
    "print(net[2].state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:10.945104Z",
     "iopub.status.busy": "2023-08-18T07:01:10.944250Z",
     "iopub.status.idle": "2023-08-18T07:01:10.951764Z",
     "shell.execute_reply": "2023-08-18T07:01:10.950790Z"
    },
    "origin_pos": 11,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.nn.parameter.Parameter'>\n",
      "Parameter containing:\n",
      "tensor([0.0887], requires_grad=True)\n",
      "tensor([0.0887])\n"
     ]
    }
   ],
   "source": [
    "print(type(net[2].bias))\n",
    "print(net[2].bias)\n",
    "print(net[2].bias.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:10.956378Z",
     "iopub.status.busy": "2023-08-18T07:01:10.955542Z",
     "iopub.status.idle": "2023-08-18T07:01:10.961810Z",
     "shell.execute_reply": "2023-08-18T07:01:10.960767Z"
    },
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[2].weight.grad == None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:10.966725Z",
     "iopub.status.busy": "2023-08-18T07:01:10.965969Z",
     "iopub.status.idle": "2023-08-18T07:01:10.972600Z",
     "shell.execute_reply": "2023-08-18T07:01:10.971655Z"
    },
    "origin_pos": 19,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))\n",
      "('0.weight', torch.Size([8, 4])) ('0.bias', torch.Size([8])) ('2.weight', torch.Size([1, 8])) ('2.bias', torch.Size([1]))\n"
     ]
    }
   ],
   "source": [
    "print(*[(name, param.shape) for name, param in net[0].named_parameters()])\n",
    "print(*[(name, param.shape) for name, param in net.named_parameters()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 21
   },
   "source": [
    "这为我们提供了另一种访问网络参数的方式，如下所示。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:10.977269Z",
     "iopub.status.busy": "2023-08-18T07:01:10.976623Z",
     "iopub.status.idle": "2023-08-18T07:01:10.983222Z",
     "shell.execute_reply": "2023-08-18T07:01:10.982309Z"
    },
    "origin_pos": 23,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0887])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net.state_dict()['2.bias'].data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:10.988088Z",
     "iopub.status.busy": "2023-08-18T07:01:10.987352Z",
     "iopub.status.idle": "2023-08-18T07:01:10.998245Z",
     "shell.execute_reply": "2023-08-18T07:01:10.997197Z"
    },
    "origin_pos": 28,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.2596],\n",
       "        [0.2596]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def block1():\n",
    "    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),\n",
    "                         nn.Linear(8, 4), nn.ReLU())\n",
    "\n",
    "def block2():\n",
    "    net = nn.Sequential()\n",
    "    for i in range(4):\n",
    "        # 在这里嵌套\n",
    "        net.add_module(f'block {i}', block1())\n",
    "    return net\n",
    "\n",
    "rgnet = nn.Sequential(block2(), nn.Linear(4, 1))\n",
    "rgnet(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:11.002889Z",
     "iopub.status.busy": "2023-08-18T07:01:11.002264Z",
     "iopub.status.idle": "2023-08-18T07:01:11.007643Z",
     "shell.execute_reply": "2023-08-18T07:01:11.006464Z"
    },
    "origin_pos": 33,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Sequential(\n",
      "    (block 0): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "    (block 1): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "    (block 2): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "    (block 3): Sequential(\n",
      "      (0): Linear(in_features=4, out_features=8, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=8, out_features=4, bias=True)\n",
      "      (3): ReLU()\n",
      "    )\n",
      "  )\n",
      "  (1): Linear(in_features=4, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(rgnet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:11.012522Z",
     "iopub.status.busy": "2023-08-18T07:01:11.011839Z",
     "iopub.status.idle": "2023-08-18T07:01:11.018508Z",
     "shell.execute_reply": "2023-08-18T07:01:11.017590Z"
    },
    "origin_pos": 37,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.1999, -0.4073, -0.1200, -0.2033, -0.1573,  0.3546, -0.2141, -0.2483])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rgnet[0][1][0].bias.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:11.023955Z",
     "iopub.status.busy": "2023-08-18T07:01:11.023046Z",
     "iopub.status.idle": "2023-08-18T07:01:11.033287Z",
     "shell.execute_reply": "2023-08-18T07:01:11.032096Z"
    },
    "origin_pos": 47,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([-0.0214, -0.0015, -0.0100, -0.0058]), tensor(0.))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def init_normal(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.normal_(m.weight, mean=0, std=0.01)\n",
    "        nn.init.zeros_(m.bias)\n",
    "net.apply(init_normal)\n",
    "net[0].weight.data[0], net[0].bias.data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:11.038321Z",
     "iopub.status.busy": "2023-08-18T07:01:11.037607Z",
     "iopub.status.idle": "2023-08-18T07:01:11.049009Z",
     "shell.execute_reply": "2023-08-18T07:01:11.047793Z"
    },
    "origin_pos": 52,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([1., 1., 1., 1.]), tensor(0.))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def init_constant(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.constant_(m.weight, 1)\n",
    "        nn.init.zeros_(m.bias)\n",
    "net.apply(init_constant)\n",
    "net[0].weight.data[0], net[0].bias.data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:11.054335Z",
     "iopub.status.busy": "2023-08-18T07:01:11.053550Z",
     "iopub.status.idle": "2023-08-18T07:01:11.063215Z",
     "shell.execute_reply": "2023-08-18T07:01:11.062244Z"
    },
    "origin_pos": 57,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0.5236,  0.0516, -0.3236,  0.3794])\n",
      "tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n"
     ]
    }
   ],
   "source": [
    "def init_xavier(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.xavier_uniform_(m.weight)\n",
    "def init_42(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        nn.init.constant_(m.weight, 42)\n",
    "\n",
    "net[0].apply(init_xavier)\n",
    "net[2].apply(init_42)\n",
    "print(net[0].weight.data[0])\n",
    "print(net[2].weight.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:11.068164Z",
     "iopub.status.busy": "2023-08-18T07:01:11.067460Z",
     "iopub.status.idle": "2023-08-18T07:01:11.079228Z",
     "shell.execute_reply": "2023-08-18T07:01:11.078069Z"
    },
    "origin_pos": 66,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Init weight torch.Size([8, 4])\n",
      "Init weight torch.Size([1, 8])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[5.4079, 9.3334, 5.0616, 8.3095],\n",
       "        [0.0000, 7.2788, -0.0000, -0.0000]], grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def my_init(m):\n",
    "    if type(m) == nn.Linear:\n",
    "        print(\"Init\", *[(name, param.shape)\n",
    "                        for name, param in m.named_parameters()][0])\n",
    "        nn.init.uniform_(m.weight, -10, 10)\n",
    "        m.weight.data *= m.weight.data.abs() >= 5\n",
    "\n",
    "net.apply(my_init)\n",
    "net[0].weight[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:11.084158Z",
     "iopub.status.busy": "2023-08-18T07:01:11.083416Z",
     "iopub.status.idle": "2023-08-18T07:01:11.092672Z",
     "shell.execute_reply": "2023-08-18T07:01:11.091537Z"
    },
    "origin_pos": 71,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([42.0000, 10.3334,  6.0616,  9.3095])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[0].weight.data[:] += 1\n",
    "net[0].weight.data[0, 0] = 42\n",
    "net[0].weight.data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:01:11.097767Z",
     "iopub.status.busy": "2023-08-18T07:01:11.096948Z",
     "iopub.status.idle": "2023-08-18T07:01:11.108904Z",
     "shell.execute_reply": "2023-08-18T07:01:11.107763Z"
    },
    "origin_pos": 77,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([True, True, True, True, True, True, True, True])\n",
      "tensor([True, True, True, True, True, True, True, True])\n"
     ]
    }
   ],
   "source": [
    "# 我们需要给共享层一个名称，以便可以引用它的参数\n",
    "shared = nn.Linear(8, 8)\n",
    "net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(),\n",
    "                    shared, nn.ReLU(),\n",
    "                    shared, nn.ReLU(),\n",
    "                    nn.Linear(8, 1))\n",
    "net(X)\n",
    "# 检查参数是否相同\n",
    "print(net[2].weight.data[0] == net[4].weight.data[0])\n",
    "net[2].weight.data[0, 0] = 100\n",
    "# 确保它们实际上是同一个对象，而不只是有相同的值\n",
    "print(net[2].weight.data[0] == net[4].weight.data[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
